/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
 * Description: all-reduce fusion header file
 * Author: lilianlin
 * Create: 2019-12-8
 */

#ifndef HCOM_ALL_REDUCE_FUSION_H
#define HCOM_ALL_REDUCE_FUSION_H

#include "hccl/base.h"
#include "common/optimizer/graph_optimizer.h"
#include "common/optimizer/graph_optimizer_types.h"
#include "graph/compute_graph.h"
#include "op_fusion_base_pub.h"
#include "platform/platform_info.h"
#include "nlohmann/json.hpp"

namespace hccl {

/**
 * @brief Fusion pass for HCOM all-reduce operators, built on OpFusionBase.
 *
 * This header only declares the interface; the member functions are defined out of line.
 * NOTE(review): the per-member descriptions below are inferred from names and signatures,
 * not from visible implementation code. Confirm them against the source file before
 * relying on them.
 */
class HcomAllReduceFusion : public OpFusionBase {
public:
    HcomAllReduceFusion();
    ~HcomAllReduceFusion() override;

    /**
     * @brief Entry point of the pass; overrides the virtual Run inherited via OpFusionBase.
     * @param graph Compute graph to process (non-const; presumably rewritten in place).
     * @return HcclResult status code.
     */
    HcclResult Run(ge::ComputeGraph& graph) override;

protected:
    // Presumably fuses the operators described by fusionSection inside graph.
    // Virtual so that derived passes can customize the fusion step.
    virtual HcclResult FuseOps(ge::ComputeGraph& graph, FusionSection& fusionSection);

    // Presumably computes how a section is split into segments (segmentNum) and the
    // segment boundary indices (segmentIndex) for a given model name and group (sGroup).
    // Virtual so that derived passes can supply their own strategy.
    virtual HcclResult GetGradSplitStrategy(const std::string& modelName, const std::string& sGroup,
        const FusionSection& fusionSection, u32& segmentNum, std::vector<u32>& segmentIndex);

    // Presumably collects fusion candidates from graph into fusionOps.
    HcclResult GetFusionOps(ge::ComputeGraph& graph, FusionInfos& fusionOps);

    // Presumably records fusion information for a single node into fusionOps.
    HcclResult GetFusionOpInfo(ge::NodePtr& nodePtr, FusionInfos& fusionOps);

    // Presumably reads the fusion options attached to nodePtr into fusionOption.
    HcclResult GetFusionOption(const ge::NodePtr& nodePtr, FusionOption &fusionOption);

    // Presumably selects the segmentation (segmentNum / segmentIndex) for one fusion section.
    HcclResult GetFusionStrategy(const ge::ComputeGraph& graph, const FusionSection& fusionSection,
        u32& segmentNum, std::vector<u32>& segmentIndex);

    // Presumably reports via bUnknownShapeNode whether nodePtr has an unknown (dynamic) shape.
    HcclResult GetNodeUnknownShapeInfo(ge::NodePtr& nodePtr, bool &bUnknownShapeNode);

    // Presumably marks the op description with an HCCL fusion flag.
    HcclResult AddHcclFusionFlag(ge::OpDescPtr& opDescPtr);

    // Presumably builds a label string from fusionOption that groups ops eligible to fuse together.
    HcclResult GenerateFusionLabel(const FusionOption &fusionOption, std::string &fusionLabel);

    // Presumably derives an identifying hash for graph into fusionHash.
    HcclResult GetFusionInformation(const ge::ComputeGraph& graph, std::string &fusionHash);

    // Presumably looks up segment indices for fusionHash, bounded by tensorLimit.
    HcclResult CalculateSegmentIndex(std::string& fusionHash, u64 tensorLimit, std::vector<u32>& segmentIndex);

    // Presumably resolves the default location of the fusion library into fusionPath.
    HcclResult GetPathFromDefault(std::string &fusionPath);

    // Presumably reads segment indices for fusionHash from the library at fusionPath.
    HcclResult GetInformationFromLibrary(std::string &fusionPath, std::string& fusionHash,
        u64 tensorLimit, std::vector<u32>& segmentIndex);

    // Presumably parses the library contents at fusionPath for fusionHash.
    // fusionPath is taken by value; the signature is kept as-is to match the out-of-line definition.
    HcclResult GetInfoFromContentedLibrary(std::string fusionPath, std::string& fusionHash,
        u64 tensorLimit, std::vector<u32>& segmentIndex);

    // Presumably splits fusionInfos into slices, using fusionInfosTemp as the working/output set.
    HcclResult GetFusionOpsSlices(FusionInfos& fusionInfos, FusionInfos& fusionInfosTemp);

private:
    // Default member initializers guarantee well-defined values even if a constructor
    // leaves a member unset. A constructor mem-initializer still takes precedence, so
    // existing initialization in the source file is unaffected.

    // Presumably: whether the graph contains a node with an unknown shape.
    bool bHasUnknownShapeNodeGraph_{false};
    // Presumably: whether the original (unfused) graph is an unknown-shape graph.
    bool unknownShapeOriginalGraph_{false};
    // Presumably: cached hash identifying the graph for fusion-library lookups.
    std::string fusionHash_;
    // Presumably: id of the graph being processed. The name is kept without a trailing
    // underscore because the out-of-line definitions reference it.
    uint32_t modelGraphId{0};
    // Presumably: tensor-size limit used when computing fusion segments.
    u64 tensorFusionLimit_{0};
};
}
#endif
